package com.yeziji.codegen;

import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.pool.DruidDataSource;
import com.mybatisflex.codegen.Generator;
import com.mybatisflex.codegen.config.GlobalConfig;
import com.mybatisflex.core.exception.MybatisFlexException;
import com.yeziji.codegen.base.ProjectParams;
import com.yeziji.codegen.utils.ProjectUtils;
import com.yeziji.common.CommonEntity;
import com.yeziji.common.IServiceImpl;
import com.yeziji.constant.VariousStrPool;
import com.yeziji.utils.DateUtils;
import com.yeziji.utils.expansion.Lists2;
import com.yeziji.utils.expansion.Str2;
import lombok.extern.slf4j.Slf4j;

import java.nio.file.Paths;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * 代码生成器
 * @author hwy
 */
@Slf4j
public class DynamicGenerator {
    private static final String JDBC_URL = "jdbc:@{DB_TYPE}://@{JDBC_URL}:@{PORT}/@{DB_NAME}";
    private static final String USER_NAME = "root";
    private static final String DEV_PWD = "a96548854";
    private static String PROJECT_PATH_FORMAT = "com.yeziji.@{PROJECT_NAME}.business.%s";
    private static String PROJECT_DIR = String.format("%s\\%s\\src\\main\\java", System.getProperty(VariousStrPool.System.USER_DIR), "@{PROJECT_DIR}");
    private static String MAPPER_PROJECT_DIR = String.format("%s\\%s\\src\\main\\resources\\mapper", System.getProperty(VariousStrPool.System.USER_DIR), "@{PROJECT_DIR}");
    private static final String DEFAULT_PORT = "3306";
    private static final String DEFAULT_JDBC_TYPE = "mysql";
    private static final String DEFAULT_JDBC_URL = "127.0.0.1";
    // 通用权限表：system_config,system_permission,system_role,system_role_permission_relation,system_user_role_relation

    /**
     * 生成数据库代码
     * <a href="https://mybatis-flex.com/zh/others/codegen.html">MyBatis-Flex 代码生成器 demo</a>
     */
    public static void main(String[] args) {
        // 1. 获取指定服务
        String generateService = chooseService();

        // 2. 用输入服务的根包名
        PROJECT_DIR = PROJECT_DIR.replace("@{PROJECT_DIR}", generateService);
        log.info("指定 PROJECT DIR ====> {}", PROJECT_DIR);
        MAPPER_PROJECT_DIR = MAPPER_PROJECT_DIR.replace("@{PROJECT_DIR}", generateService);
        log.info("指定 MAPPER_PROJECT_DIR ====> {}", MAPPER_PROJECT_DIR);

        // 3. 命中根目录包名
        if (generateService.endsWith("service")) {
            String projectPath = getPathByService(generateService);
            System.out.println("service 模块 == 自动命名为: " + projectPath);
            PROJECT_PATH_FORMAT = PROJECT_PATH_FORMAT.replace("@{PROJECT_NAME}", projectPath);
        } else {
            String projectPath = scanner("非 service 模块 >>> 请自定义命名");
            PROJECT_PATH_FORMAT = PROJECT_PATH_FORMAT.replace("@{PROJECT_NAME}", projectPath);
        }
        log.info("指定 PROJECT_PATH_FORMAT ====> {}", PROJECT_PATH_FORMAT);

        // 4. 指定数据源
        DruidDataSource dataSource = new DruidDataSource();
        // 自定义连接地址
        String dbPort = scannerOrElse(String.format("自定义端口（0：默认值: [%s]）", DEFAULT_PORT), DEFAULT_PORT);
        String dbType = scannerOrElse(String.format("自定义数据库类型（0：默认值: [%s]）", DEFAULT_JDBC_TYPE), DEFAULT_JDBC_TYPE);
        String dbConnectUrl = scannerOrElse(String.format("自定义连接地址（0：默认值: [%s]）", DEFAULT_JDBC_URL), DEFAULT_JDBC_URL);
        String dbName;
        if (generateService.endsWith("service")) {
            String serviceDbName = getDbNameByService(generateService);
            dbName = scannerOrElse(String.format("自定义指定数据库（0：默认值: [%s]）", serviceDbName), serviceDbName);
        } else {
            dbName = scanner("自定义指定数据库");
        }
        String dbUrl = JDBC_URL.replace("@{DB_TYPE}", dbType)
                .replace("@{JDBC_URL}", dbConnectUrl)
                .replace("@{PORT}", dbPort)
                .replace("@{DB_NAME}", dbName);
        // 自定义用户名和密码
        String dbUsername = scannerOrElse(String.format("自定义账号（0：默认值: [%s]）", USER_NAME), USER_NAME);
        String dbPwd = scannerOrElse(String.format("自定义密码（0：默认值: [%s]）", DEV_PWD), DEV_PWD);

        dataSource.setUrl(dbUrl);
        dataSource.setUsername(dbUsername);
        dataSource.setPassword(dbPwd);
        log.info("指定 JDBC_URL ====> {}", dbUrl);
        log.info("指定 JDBC_USERNAME ====> {}", dbUsername);
        log.info("指定 JDBC_PWD ====> {}", dbPwd);
        Connection connection = null;
        try {
            long connectStartTime = System.currentTimeMillis();
            log.info("===== 开始测试连接 =====");
            connection = DriverManager.getConnection(dbUrl, dbUsername, dbPwd);
            log.info("===== 连接成功, 耗时: {} =====", DateUtils.formatTime(System.currentTimeMillis() - connectStartTime));
        } catch (SQLException e) {
            log.error("===== 连接失败，请确认信息无误：url={}, username={}, password={} =====", dbUrl, dbUsername, dbPwd, e);
            return;
        } finally {
            if (connection != null) {
                try {
                    log.info("===== 自动关闭测试连接 =====");
                    connection.close();
                    log.info("===== 自动关闭测试连接成功 =====");
                } catch (SQLException e) {
                    log.error("===== 数据库连接关闭异常 =====");
                }
            }
        }
        // 5. 生成代码
        String model = scanner("选择需要生成的模式：0 单模块模式， 1 多模块模式");
        if ("0".equals(model)) {
            generate(scanner("模块名"), dataSource);
        } else if ("1".equals(model)) {
            Set<String> moduleNames = Arrays.stream(scanner("多个模块(输入数据库表名); 用英文逗号隔开; 下划线会自动转驼峰").split(","))
                    .collect(Collectors.toSet());
            log.info("moduleNames -> {}", moduleNames);
            String scanner = scanner("是否忽略第一个 _ 的名称? no: 0, yes: 1");
            boolean ignoreFirst = "1".equals(scanner);
            moduleNames.forEach(tableName -> {
                String moduleName = tableName;
                if (ignoreFirst) {
                    moduleName = moduleName.substring(moduleName.indexOf("_") + 1);
                }
                String underlineCase = StrUtil.toCamelCase(moduleName);
                generate(tableName, underlineCase, dataSource);
            });
        }
    }

    /**
     * 选择生成代码的服务
     *
     * @return {@link String} 服务名称
     */
    public static String chooseService() {
        List<ProjectParams> projects = ProjectUtils.getProjects();
        if (Lists2.isEmpty(projects)) {
            throw new MybatisFlexException("扫描不到任何模块在当前项目中");
        }

        // 提示输入
        System.out.println("请选择指定服务的序号(如：1)【只能选择一个服务】");
        Map<Integer, String> folderMap =
                IntStream.range(0, projects.size())
                        .boxed()
                        .collect(Collectors.toMap(i -> i + 1, i -> projects.get(i).getProjectName()));
        for (Map.Entry<Integer, String> entry : folderMap.entrySet()) {
            System.out.println(entry.getKey() + ". " + entry.getValue());
        }

        // 选择服务
        Scanner scanner = new Scanner(System.in);
        if (scanner.hasNext()) {
            String ipt = scanner.next();
            if (!ipt.isBlank()) {
                int i = Integer.parseInt(ipt);
                return Optional.ofNullable(folderMap.get(i)).orElseThrow(() -> new MybatisFlexException("服务不存在"));
            }
        }
        throw new MybatisFlexException("请选择需要生成代码的服务");
    }

    /**
     * 根据 service 获取指定的目录名
     *
     * @param service 服务
     * @return {@link String} 目录名
     */
    public static String getPathByService(String service) {
        if (service.startsWith("yzj") && service.endsWith("service")) {
            return Str2.toCamelCase(Str2.replaceAllToEmpty(Str2.replaceAllToEmpty(service, "yzj-"), "-service"), '-', '_');
        }
        throw new MybatisFlexException("获取目录名失败: " + service);
    }

    /**
     * 根据 service 获取指定的数据库名
     *
     * @param service 服务
     * @return {@link String} 数据库名
     */
    public static String getDbNameByService(String service) {
        switch (service) {
            case "yzj-website-service":
                return "yzj_website";
            case "yzj-crawler-service":
                return "yzj_crawler";
            case "yzj-pay-service":
                return "yzj_pay";
            default:
                throw new MybatisFlexException("模块目录不存在");
        }
    }

    /**
     * 获取模板 tpl
     *
     * @param tplName 模板名称
     * @return {@link String} 模板路径
     */
    private static String getTemplatesTplDir(String tplName) {
        return String.format("%s\\yzj-code-generator\\templates\\enjoy\\%s.tpl", System.getProperty("user.dir"), tplName);
    }

    /**
     * 生成代码
     *
     * @param tableName  数据表名
     * @param moduleName 模块名
     * @param dataSource 生成的数据库
     */
    private static void generate(String tableName, String moduleName, DruidDataSource dataSource) {
        log.info("开始生成 {} 模块", moduleName);
        GlobalConfig globalConfig = createGlobalConfigUseStyle(tableName, moduleName);
        //通过 datasource 和 globalConfig 创建代码生成器
        Generator generator = new Generator(dataSource, globalConfig);
        //生成代码
        generator.generate();
        log.info("模块 {} 生成完毕", moduleName);
    }


    /**
     * 用户输入提示信息
     *
     * @param tips 提示信息
     * @return {@link String} 返回提示内容
     */
    private static String scanner(String tips) {
        Scanner scanner = new Scanner(System.in);
        System.out.println("请输入" + tips + "：");
        if (scanner.hasNext()) {
            String ipt = scanner.nextLine();
            if (!ipt.isBlank()) {
                return ipt;
            }
        }
        throw new MybatisFlexException("请输入正确的" + tips + "！");
    }

    private static String scannerOrElse(String tips, String orElse) {
        try {
            String scanner = scanner(tips);
            if ("0".equals(scanner)) {
                return orElse;
            }
            return scanner;
        } catch (Exception e) {
            log.warn("{} ====> 默认值: {}", tips, orElse);
            return orElse;
        }
    }

    /**
     * 生成模块配置
     *
     * @param tableName  数据表名称
     * @param moduleName 模块名称
     * @return {@link GlobalConfig} 全局配置
     */
    public static GlobalConfig createGlobalConfigUseStyle(String tableName, String moduleName) {
        //创建配置内容
        GlobalConfig globalConfig = new GlobalConfig();

        //设置自定义模板
        globalConfig.getTemplateConfig()
                .setEntity(getTemplatesTplDir("entity"))
                .setTableDef(getTemplatesTplDir("tableDef"));

        //设置根包
        globalConfig.getPackageConfig()
                .setSourceDir(PROJECT_DIR)
                .setMapperXmlPath(MAPPER_PROJECT_DIR)
                .setBasePackage(String.format(PROJECT_PATH_FORMAT, moduleName));

        // 设置生成 table
        if (tableName != null) {
            globalConfig.getStrategyConfig()
                    .setIgnoreColumns("id", "is_delete", "create_time", "update_time")
                    .setGenerateTable(tableName);
        } else {
            String filterMode = scanner("表名筛选模式：0. 筛选全表名, 1. 筛选表名前缀, 2. 排除全表名");
            String[] tablePrefix = null;
            Set<String> tables = new HashSet<>();
            switch (filterMode) {
                case "0":
                case "2":
                    tables =
                            Arrays.stream(scanner("表名, 多个表则使用英文逗号分割").split(","))
                                    .map(String::trim)
                                    .collect(Collectors.toSet());
                    break;
                case "1":
                    tablePrefix =
                            Arrays.stream(scanner("表名, 多个表则使用英文逗号分割").split(","))
                                    .map(String::trim)
                                    .distinct()
                                    .toArray(String[]::new);
                    break;
            }
            globalConfig.getStrategyConfig()
                    .setIgnoreColumns("id", "is_delete", "create_time", "update_time");
            switch (filterMode) {
                case "1":
                    globalConfig.getStrategyConfig().setTablePrefix(tablePrefix);
                    break;
                case "0":
                    globalConfig.getStrategyConfig().setGenerateTables(tables);
                    break;
                case "2":
                    globalConfig.getStrategyConfig().setUnGenerateTables(tables);
                    break;
            }
        }

        //设置生成 entity
        globalConfig.enableEntity()
                .setWithLombok(true)
                .setClassSuffix("Entity")
                .setSuperClass(CommonEntity.class);

        // 设置生成 mapper
        globalConfig.enableMapper()
                .setMapperAnnotation(true);

        // 设置生成 mapper xml
        globalConfig.enableMapperXml();

        // 设置生成 service
        globalConfig.enableService();

        // 设置生成 service impl
        globalConfig.enableServiceImpl()
                .setSuperClass(IServiceImpl.class);

        // 设置生成 javadoc
        globalConfig.getJavadocConfig()
                .setAuthor("system");

        // 生成 tableDef
        globalConfig.enableTableDef()
                .setOverwriteEnable(true);

        return globalConfig;
    }

    /**
     * 生成代码
     *
     * @param moduleName 模块名
     * @param dataSource 生成的数据库
     */
    private static void generate(String moduleName, DruidDataSource dataSource) {
        generate(null, moduleName, dataSource);
    }
}
